# # Anonymised
import math
import sys
from argparse import Namespace
from typing import Tuple

import torch
from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
import backbone.Model_Avg as Model_Avg

from utils.loggers import *
from utils.status import ProgressBar
from utils.batch_norm import bn_track_stats

try:
    import wandb
except ImportError:
    wandb = None


def compute_class_means(model, dataset):
    """
    Computes a vector representing mean features for each class.

    The exemplars held in ``model.buffer`` are fetched with the dataset's
    normalization transform, grouped by label, and pushed through
    ``model.net(..., returnt='features')`` in chunks of
    ``model.args.batch_size``. The features of every exemplar of a class are
    averaged with equal weight, regardless of how the chunks fall.

    :param model: object exposing ``buffer``, ``net``, ``device`` and
                  ``args.batch_size``
    :param dataset: object exposing ``get_normalization_transform()``
    :return: a tensor with one flattened mean-feature row per distinct label
             found in the buffer, rows ordered as ``labels.unique()``
    """
    transform = dataset.get_normalization_transform()
    buf_data = model.buffer.get_all_data(transform)
    examples, labels = buf_data[0], buf_data[1]
    # Move the labels to the CPU once instead of once per exemplar per class.
    labels_cpu = labels.cpu()
    batch_size = model.args.batch_size

    class_means = []
    for _y in labels_cpu.unique():
        x_buf = torch.stack(
            [example for example, label in zip(examples, labels_cpu)
             if label == _y]
        ).to(model.device)
        # NOTE(review): bn_track_stats(model, False) presumably keeps the
        # batch-norm running statistics untouched here; confirm in
        # utils.batch_norm.
        with bn_track_stats(model, False):
            feat_sum = None
            for start in range(0, len(x_buf), batch_size):
                chunk = x_buf[start:start + batch_size]
                chunk_sum = model.net(chunk, returnt='features').sum(0)
                # Accumulate a plain sum and divide once at the end: averaging
                # the per-chunk means pairwise would weight later chunks (and
                # a short final chunk) more heavily than earlier ones.
                feat_sum = chunk_sum if feat_sum is None else feat_sum + chunk_sum
        class_means.append((feat_sum / len(x_buf)).flatten())
    return torch.stack(class_means)


def nmc_out(class_means, model, x):
    """
    Scores a batch of inputs against a set of class means.

    The features returned by ``model.net(x, returnt='features')`` are flattened
    per sample and compared with every row of ``class_means``.

    :param class_means: the per-class mean features (assumed to be laid out as
                        one row per class; confirm against the caller)
    :param model: object whose ``net`` produces the features
    :param x: the input batch
    :return: the negated squared Euclidean distance between each sample and
             each class mean, so the largest score is the nearest mean
    """
    embeddings = model.net(x, returnt='features')
    flat = embeddings.view(embeddings.size(0), -1)
    # Broadcasting pairs every sample with every class mean.
    offsets = class_means.unsqueeze(0) - flat.unsqueeze(1)
    squared_distance = (offsets ** 2).sum(dim=2)
    return -squared_distance


def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Restricts the responses in ``outputs`` to the classes of task ``k``.

    The tensor is modified in place: every column before the first class of
    task ``k`` and every column from the end of task ``k`` up to
    ``N_TASKS * N_CLASSES_PER_TASK`` is set to -inf. It is used to obtain the
    results for the task-il setting.
    :param outputs: the output tensor
    :param dataset: the continual dataset
    :param k: the task index
    """
    classes_per_task = dataset.N_CLASSES_PER_TASK
    task_start = k * classes_per_task
    task_end = task_start + classes_per_task
    n_classes = dataset.N_TASKS * classes_per_task
    neg_inf = float('-inf')

    outputs[:, :task_start] = neg_inf
    outputs[:, task_end:n_classes] = neg_inf


def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model on the test loaders of the dataset.

    The network is switched to eval mode for the duration of the call and its
    previous training flag is restored before returning. When ``model.nmc`` is
    set, predictions come from the distance to the class means computed from
    the model's buffer instead of from the model's own forward pass.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param last: if True, only the last loader in ``dataset.test_loaders``
                 is evaluated
    :return: a tuple of lists, one entry per evaluated loader: the unmasked
             accuracy (0 when the model is not class-il compatible) and the
             accuracy after masking the other tasks' classes (only counted
             when the dataset setting is class-il, 0 otherwise)
    """
    was_training = model.net.training
    model.net.eval()

    if model.nmc:
        with torch.no_grad():
            class_means = compute_class_means(model, dataset).squeeze()

    class_il_accs = []
    task_il_accs = []
    for task_id, loader in enumerate(dataset.test_loaders):
        if last and task_id < len(dataset.test_loaders) - 1:
            continue

        n_seen = 0.0
        hits = 0.0
        masked_hits = 0.0
        for inputs, labels in loader:
            with torch.no_grad():
                inputs = inputs.to(model.device)
                labels = labels.to(model.device)

                if 'class-il' not in model.COMPATIBILITY:
                    # Models without class-il support are told the task index.
                    outputs = model(inputs, task_id)
                elif model.nmc:
                    outputs = nmc_out(class_means, model, inputs)
                else:
                    outputs = model(inputs)

                pred = torch.max(outputs.data, 1)[1]
                hits += torch.sum(pred == labels).item()
                n_seen += labels.shape[0]

                if dataset.SETTING == 'class-il':
                    # Masking overwrites `outputs` in place, so it has to come
                    # after the unmasked prediction above.
                    mask_classes(outputs, dataset, task_id)
                    pred = torch.max(outputs.data, 1)[1]
                    masked_hits += torch.sum(pred == labels).item()

        if 'class-il' in model.COMPATIBILITY:
            class_il_accs.append(hits / n_seen * 100)
        else:
            class_il_accs.append(0)
        task_il_accs.append(masked_hits / n_seen * 100)

    model.net.train(was_training)
    return class_il_accs, task_il_accs


def model_avg_eval(model, dataset, avg_net, nmc, t):
    """
    Evaluates the averaged network in place of the model's own network.

    ``model.net`` is temporarily replaced by ``avg_net.module`` (put in eval
    mode first), the mean accuracies are printed, and the original network is
    put back afterwards.
    :param model: the model whose network is swapped for the evaluation
    :param dataset: the continual dataset at hand
    :param avg_net: the averaging wrapper; its ``module`` is the evaluated net
    :param nmc: if truthy, also run the nearest-mean evaluation on the
                averaged network
    :param t: the task index, forwarded to ``nmc_eval``
    """
    original_net = model.net
    avg_net.eval()
    model.net = avg_net.module

    class_il_mean, task_il_mean = np.mean(evaluate(model, dataset), axis=1)
    class_il_text = "{:.4}".format(class_il_mean.item())
    task_il_text = "{:.4}".format(task_il_mean.item())
    print("\nmodel avg eval:" + "      [Class-IL]: " + class_il_text
          + " %      [Task-IL]: " + task_il_text + " %")

    if nmc:
        # Prefix for the line that nmc_eval prints next.
        print("model avg ",  end='')
        nmc_eval(model, dataset, t)

    model.net = original_net


def nmc_eval(model, dataset, t):
    """
    Evaluates the model with nearest-mean classification and prints the result.

    ``model.nmc`` is switched on for the evaluation pass and switched off
    again afterwards.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param t: the task index (not used by this function)
    """
    model.nmc = True
    class_il_mean, task_il_mean = np.mean(evaluate(model, dataset), axis=1)
    class_il_text = "{:.4}".format(class_il_mean.item())
    task_il_text = "{:.4}".format(task_il_mean.item())
    print("nmc eval:" + "      [Class-IL]: " + class_il_text
          + " %      [Task-IL]: " + task_il_text + " %")
    model.nmc = False

def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.

    Tasks are visited one after the other. For each task the model is trained
    for ``model.args.n_epochs`` epochs, then evaluated on every loader in
    ``dataset.test_loaders``; the accuracies are printed and, depending on the
    flags in ``args``, handed to the file logger and to wandb.

    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    print(args)

    # wandb logging is active unless args.nowand is set; the run URL is stored
    # on args so that it ends up wherever args is later written.
    if not args.nowand:
        assert wandb is not None, "Wandb not installed, please install it or run without wandb"
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args))
        args.wandb_url = wandb.run.get_url()

    model.net.to(model.device)
    # One entry per finished task, each entry a list of per-loader accuracies.
    results, results_mask_classes = [], []

    if args.decorr is not None:
        model.net.decorr_to(model.device)

    # Nearest-mean prediction stays off here regardless of args.nmc; nmc_eval()
    # switches it on only for the duration of its own evaluation pass.
    model.nmc = False

    if not args.disable_log:
        logger = Logger(dataset.SETTING, dataset.NAME, model.NAME)

    progress_bar = ProgressBar(verbose=not args.non_verbose)

    if not args.ignore_other_metrics:
        # Evaluate the untrained model on a separate dataset instance; the
        # result is the baseline passed to logger.add_fwt at the end.
        dataset_copy = get_dataset(args)
        for t in range(dataset.N_TASKS):
            model.net.train()
            # presumably each call registers the next task's test loader on
            # the copy, so that all tasks are available below - TODO confirm
            _, _ = dataset_copy.get_data_loaders()
        if model.NAME != 'icarl' and model.NAME != 'pnn' and not args.nmc:
            random_results_class, random_results_task = evaluate(model, dataset_copy)

    print(file=sys.stderr)
    for t in range(dataset.N_TASKS):
        model.net.train()
        train_loader, test_loader = dataset.get_data_loaders()
        if hasattr(model, 'begin_task'):
            model.begin_task(dataset)
        if t and not args.ignore_other_metrics:
            # Before training on task t, evaluate only the last test loader and
            # append that accuracy to the lists stored for the previous task.
            accs = evaluate(model, dataset, last=True)
            results[t-1] = results[t-1] + accs[0]
            if dataset.SETTING == 'class-il':
                results_mask_classes[t-1] = results_mask_classes[t-1] + accs[1]

        # Negative args.decorr: decorrelation weights are refreshed BEFORE
        # training on the task. abs(args.decorr) selects the variant:
        #   7, 9, 10 -> reset the weights first
        #   5, 6, 8  -> refresh only at the first task
        #   8, 9, 10 -> refresh from a fresh dataset instance, walking over
        #               its tasks (10 stops after task t)
        #   others   -> refresh from the current task's training set
        if (args.decorr is not None) and args.decorr < 0:
            if abs(args.decorr) in [7, 9, 10]:
                model.net.reset_decorr_weights()
            if (abs(args.decorr) not in [5, 6, 8]) or t == 0:
                if abs(args.decorr) in [8, 9, 10]:
                    dataset_copy = get_dataset(args)
                    for temp_t in range(dataset.N_TASKS):
                        if abs(args.decorr) == 10 and temp_t > t:
                            break
                        temp_train_loader, _ = dataset_copy.get_data_loaders()
                        model.net.update_decorr_weights(temp_train_loader.dataset)
                else:
                    model.net.update_decorr_weights(train_loader.dataset)

        scheduler = dataset.get_scheduler(model, args)
        for epoch in range(model.args.n_epochs):
            # The 'joint' model skips the per-batch loop (and the scheduler
            # step) entirely.
            if args.model == 'joint':
                continue
            for i, data in enumerate(train_loader):
                # Debug mode processes only the first four batches per epoch.
                if args.debug_mode and i > 3:
                    break
                # Batches carry an extra logits element when the training
                # dataset has a `logits` attribute.
                if hasattr(dataset.train_loader.dataset, 'logits'):
                    inputs, labels, not_aug_inputs, logits = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    logits = logits.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs, logits)
                else:
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(model.device), labels.to(
                        model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs)
                # Abort as soon as the loss diverges to NaN.
                assert not math.isnan(loss)
                progress_bar.prog(i, len(train_loader), epoch, t, loss)

            if scheduler is not None:
                scheduler.step()

        if hasattr(model, 'end_task'):
            model.end_task(dataset)

        # Weight averaging across tasks: the averaged network is created after
        # the first task and updated after every later one. If both
        # args.model_avg and args.EMA are set, the EMA wrapper is the one kept.
        if args.model_avg or (args.EMA is not None):
            if t == 0:
                if args.model_avg:
                    avg_net = Model_Avg.AvgOverChunks(model.net)
                if args.EMA is not None:
                    avg_net = Model_Avg.EMA(model.net, decay=args.EMA)
            else:
                avg_net.update(model.net)
            model_avg_eval(model, dataset, avg_net, args.nmc, t)

        if args.nmc:
            nmc_eval(model, dataset, t)

        # The accuracies that are logged and stored come from this regular
        # (non-averaged, non-NMC) evaluation.
        accs = evaluate(model, dataset)

        # Positive args.decorr: same variants as the negative case above, but
        # the refresh happens AFTER training and is skipped after the last
        # task.
        if (args.decorr is not None) and args.decorr > 0 and t != dataset.N_TASKS-1:
            if abs(args.decorr) in [7, 9, 10]:
                model.net.reset_decorr_weights()
            if (abs(args.decorr) not in [5, 6, 8]) or t == 0:
                if abs(args.decorr) in [8, 9, 10]:
                    dataset_copy = get_dataset(args)
                    for temp_t in range(dataset.N_TASKS):
                        if abs(args.decorr) == 10 and temp_t > t:
                            break
                        temp_train_loader, _ = dataset_copy.get_data_loaders()
                        model.net.update_decorr_weights(temp_train_loader.dataset)
                else:
                    model.net.update_decorr_weights(train_loader.dataset)

        results.append(accs[0])
        results_mask_classes.append(accs[1])

        # Row 0 is the mean of the unmasked accuracies, row 1 of the masked
        # ones.
        mean_acc = np.mean(accs, axis=1)
        print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)

        if not args.disable_log:
            logger.log(mean_acc)
            logger.log_fullacc(accs)

        if not args.nowand:
            d2={'RESULT_class_mean_accs': mean_acc[0], 'RESULT_task_mean_accs': mean_acc[1],
                **{f'RESULT_class_acc_{i}': a for i, a in enumerate(accs[0])},
                **{f'RESULT_task_acc_{i}': a for i, a in enumerate(accs[1])}}

            wandb.log(d2)



    # Summary metrics over the whole run. Forward transfer needs the
    # untrained-model baseline, which exists only under the same condition as
    # the one used before the task loop.
    if not args.disable_log and not args.ignore_other_metrics:
        logger.add_bwt(results, results_mask_classes)
        logger.add_forgetting(results, results_mask_classes)
        if model.NAME != 'icarl' and model.NAME != 'pnn'and not args.nmc:
            logger.add_fwt(results, random_results_class,
                    results_mask_classes, random_results_task)

    if not args.disable_log:
        logger.write(vars(args))
        if not args.nowand:
            d = logger.dump()
            d['wandb_url'] = wandb.run.get_url()
            wandb.log(d)

    if not args.nowand:
        wandb.finish()
